package bjtuoj;

import java.util.Arrays;
import java.util.Scanner;

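/**
 * ClassName: Homework_4_E
 * Description: Unit-time task scheduling with deadlines and penalties. Reads n, then n
 * deadlines, then n penalties from standard input, and prints the total penalty of the tasks
 * that cannot be placed by their deadlines when tasks are taken greedily in order of
 * decreasing penalty (a union-find over time slots finds the latest free slot).
 * date: 2021-12-09 15:15
 *
 * @author liyifan
 */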
public class Homework_4_E {
    /**
     * Reads the task count {@code n}, then {@code n} deadlines, then {@code n} penalties from
     * standard input, and prints the resulting minimum total penalty (no trailing newline).
     *
     * @param args unused
     */
    public static void main(String[] args) {
        try (Scanner scanner = new Scanner(System.in)) {
            int n = scanner.nextInt();
            int[] deadlines = new int[n];
            int[] penalties = new int[n];
            for (int i = 0; i < n; i++) {
                deadlines[i] = scanner.nextInt();
            }
            for (int i = 0; i < n; i++) {
                penalties[i] = scanner.nextInt();
            }
            System.out.printf("%d", minTotalPenalty(deadlines, penalties));
        }
    }

    /**
     * Computes the total penalty of a greedy schedule of unit-time tasks into slots {@code 1..n},
     * where {@code n} is the number of tasks.
     *
     * <p>Tasks are considered in order of decreasing penalty. Each one takes the latest free slot
     * not after its deadline, located with a union-find over slots. A task with no such slot is
     * missed, and its penalty is added to the total.
     *
     * @param deadlines deadline of each task; values above {@code n} are treated as {@code n},
     *     and values below 1 mean the task can never be scheduled
     * @param penalties penalty paid if the task at the same index misses its deadline
     * @return the sum of penalties of the tasks that could not be scheduled
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    static long minTotalPenalty(int[] deadlines, int[] penalties) {
        int n = deadlines.length;
        if (penalties.length != n) {
            throw new IllegalArgumentException(
                    "deadlines.length=" + n + " but penalties.length=" + penalties.length);
        }

        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        // Integer.compare avoids the overflow that "b - a" suffers for large or negative values.
        Arrays.sort(order, (a, b) -> Integer.compare(penalties[b], penalties[a]));

        // parent[s] leads to the latest free slot <= s; slot 0 is a sentinel meaning "none left".
        int[] parent = new int[n + 1];
        for (int s = 0; s <= n; s++) {
            parent[s] = s;
        }

        long total = 0; // long: a sum of many int penalties can exceed Integer.MAX_VALUE
        for (int task : order) {
            // A deadline >= n may use every slot; clamping also keeps the index in bounds.
            int latest = Math.max(0, Math.min(deadlines[task], n));
            int slot = find(parent, latest);
            if (slot == 0) {
                total += penalties[task];
            } else {
                // Occupy the slot; later lookups that reach it fall through to slot - 1.
                parent[slot] = slot - 1;
            }
        }
        return total;
    }

    /**
     * Follows parent links from {@code i} until reaching an entry that points to itself, and
     * returns that entry. Every entry visited on the way is then pointed directly at it
     * (path compression).
     *
     * <p>Implemented iteratively, so long chains of links cannot overflow the call stack.
     * Assumes the links contain no cycle other than self-loops.
     *
     * @param punish parent array of the union-find structure; modified in place
     * @param i starting index, must be within {@code punish}
     * @return the index reached by following links from {@code i}
     */
    static int find(int[] punish, int i) {
        int root = i;
        while (punish[root] != root) {
            root = punish[root];
        }
        // Second pass: point every node on the path straight at the root.
        while (punish[i] != root) {
            int next = punish[i];
            punish[i] = root;
            i = next;
        }
        return root;
    }
}

/**
 * Mutable holder for a single task: the latest time slot it may occupy and the penalty
 * charged if it cannot be placed by then. Both fields start at zero.
 */
class Pair {
    /** Deadline of the task, i.e. the latest slot index it may be assigned to. */
    int finish;

    /** Penalty charged when the task cannot be scheduled by its deadline. */
    int punish;

    /** Creates a task with both fields zero; callers assign them directly. */
    Pair() {
    }
}
